from utils.utils_keras import *
from model.KerasLstm import MyLSTM
from model.LstmCap import LstmCap
import os
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
# Expose only the first GPU to this process.
# NOTE(review): CUDA reads this variable when GPU devices are first
# initialized, so it must be set before any session is created. Nothing below
# touches the GPU before `tf.Session`, but confirm the wildcard-imported
# utils do not create a session at import time.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


# Grow GPU memory on demand instead of reserving all of it up front.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=config)
# Make Keras run its models in this configured session.
KTF.set_session(sess)

def main():
    """Load the dataset, report its size, and train an ``LstmCap`` on it.

    A ``MyLSTM(emb_size=200)`` instance is also constructed, but it is not
    used afterwards in this function.
    """
    data = getDataset()
    # NOTE(review): this model is never referenced again. The construction is
    # kept because MyLSTM's constructor may have side effects, such as graph
    # building or Keras layer-name counters, that later training relies on.
    # Confirm before removing.
    unused_lstm = MyLSTM(emb_size=200)
    trainer = LstmCap()
    print(len(data))
    trainer.train(data)

if __name__ == "__main__":
    # Train only when run as a script, not when this module is imported.
    main()